{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Creating a new press\n",
    "\n",
    "In this guide, we will walk you through the process of creating a new press."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from contextlib import contextmanager\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from transformers import pipeline\n",
    "\n",
    "from kvpress import BasePress, KnormPress, ScorerPress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.\n",
      "Device set to use cuda:0\n"
     ]
    }
   ],
   "source": [
    "# Load pipeline\n",
    "\n",
    "device = \"cuda:0\"\n",
    "ckpt = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "attn_implementation = \"flash_attention_2\"\n",
    "pipe = pipeline(\"kv-press-text-generation\", model=ckpt, device=device, torch_dtype=\"auto\", model_kwargs={\"attn_implementation\":attn_implementation})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data\n",
    "\n",
    "context = \"In this step-by-step guide, you will learn how to create a new press in kvpress !\"\n",
    "question = \"\\nWhat is the purpose of this guide?\"\n",
    "tokens = pipe.tokenizer(context, return_tensors=\"pt\").to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Understanding how press work"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A press registers a forward hook to each attention layer during the pre-filling phase.  Immediately after the forward pass, the hook is called, and it compresses the KV cache."
   ]
  },
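  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The hook mechanism described above can be sketched with plain PyTorch. This is a toy illustration, not the kvpress implementation: `toy_press`, `toy_layer`, `toy_input` and the `keep` parameter are hypothetical names, and an `nn.Linear` stands in for an attention layer. The context manager registers a forward hook on entry, the hook truncates the layer output right after its forward pass, and the hook is removed on exit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy sketch of the hook mechanism (not the kvpress implementation)\n",
    "from contextlib import contextmanager\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "@contextmanager\n",
    "def toy_press(model: nn.Module, keep: int):\n",
    "    \"\"\"Truncate the output of every nn.Linear to `keep` features while the context is active.\"\"\"\n",
    "\n",
    "    def hook(module, inputs, output):\n",
    "        # A forward hook runs right after module.forward; a non-None return value replaces the output\n",
    "        return output[..., :keep]\n",
    "\n",
    "    handles = [m.register_forward_hook(hook) for m in model.modules() if isinstance(m, nn.Linear)]\n",
    "    try:\n",
    "        yield\n",
    "    finally:\n",
    "        # Remove the hooks so the model behaves normally outside the context\n",
    "        for handle in handles:\n",
    "            handle.remove()\n",
    "\n",
    "\n",
    "toy_layer = nn.Linear(4, 8)  # stands in for an attention layer\n",
    "toy_input = torch.randn(2, 4)\n",
    "\n",
    "with torch.no_grad(), toy_press(toy_layer, keep=6):\n",
    "    print(f\"Output shape w/ hook:  {toy_layer(toy_input).shape}\")  # torch.Size([2, 6])\n",
    "\n",
    "with torch.no_grad():\n",
    "    print(f\"Output shape w/o hook: {toy_layer(toy_input).shape}\")  # torch.Size([2, 8])"
   ]
  },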
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)\n",
      "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cache shape w/o press: torch.Size([1, 2, 20, 128])\n",
      "Cache shape w/ press:  torch.Size([1, 2, 15, 128])\n",
      "\n",
      "The purpose of this step-by-step guide is to provide instructions on how to create a new press in kvpress. The guide is designed to help users understand the process of setting up a new press in the kvpress platform.\n"
     ]
    }
   ],
   "source": [
    "compression_ratio = 0.25\n",
    "press = KnormPress(compression_ratio)\n",
    "\n",
    "with torch.no_grad():\n",
    "    outputs_without_press = pipe.model(**tokens, output_hidden_states=True)\n",
    "\n",
    "with torch.no_grad(), press(pipe.model):\n",
    "    output_with_press = pipe.model(**tokens)\n",
    "\n",
    "print(f\"Cache shape w/o press: {outputs_without_press.past_key_values[0][0].shape}\")\n",
    "print(f\"Cache shape w/ press:  {output_with_press.past_key_values[0][0].shape}\\n\")\n",
    "\n",
    "# The `KVPressTextGenerationPipeline` simply applies the `press` as above on the context tokens (see `_forward` method for more details).\n",
    "print(pipe(context, question=question, press=press)[\"answer\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Creating your own press\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.1 Updating the `score` method"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The easiest way to create a new press is to create a class that inherits from `ScorerPress` and implement a `score` method that computes the score for each key-value pair.\n",
    "\n",
    "The arguments of the `score` method are obtained from the forward hook:\n",
    "- `module`: the attention layer\n",
    "- `hidden_states`: the input of the attention layer\n",
    "- `keys` and `values`: the key-value pairs from the attention layer\n",
    "- `attentions`: the attention weights, only available with `attn_implementation=\"eager\"`\n",
    "\n",
    "In this first example, we will reproduce the `KnormPress` where the score of a key-value pair is simply the opposite of the norm of the key vector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyKnormPress(ScorerPress):\n",
    "    def score(\n",
    "        self,\n",
    "        module: nn.Module,\n",
    "        hidden_states: torch.Tensor,\n",
    "        keys: torch.Tensor,\n",
    "        values: torch.Tensor,\n",
    "        attentions: torch.Tensor,\n",
    "        kwargs,\n",
    "    ) -> torch.Tensor:\n",
    "\n",
    "        scores = -keys.norm(dim=-1)\n",
    "\n",
    "        # For demonstration, we show some details on the shape for the first layer\n",
    "        if module.layer_idx == 0:\n",
    "            print(f\"module: {module}\")\n",
    "            print(f\"Number of key value heads: {module.config.num_key_value_heads}\")\n",
    "            print(f\"Sequence length: {hidden_states.shape[1]}\")\n",
    "            print()\n",
    "            print(f\"hidden_states shape: {hidden_states.shape}\")\n",
    "            print(f\"keys shape:          {keys.shape}\") # shape (bhnd)\n",
    "            print(f\"values shape:        {values.shape}\") # shape (bhnd)\n",
    "            print(f\"score shape:         {scores.shape}\") # shape (bhn)\n",
    "            print()\n",
    "        \n",
    "        return scores\n",
    "\n",
    "\n",
    "press = MyKnormPress(compression_ratio)\n",
    "print(pipe(context, question=question, press=press)[\"answer\"])"
   ]
  },
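  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A `ScorerPress` uses these scores to decide which key-value pairs to keep. The cell below is a minimal sketch of such score-based pruning, not the library's code. It assumes that the highest-scoring pairs are kept and that `int(n * (1 - compression_ratio))` pairs remain, which is consistent with the cache shrinking from 20 to 15 tokens at a ratio of 0.25 in section 1. `prune_by_score` and the `toy_*` tensors are hypothetical names introduced for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "# Sketch only: assumes the pairs with the highest scores are kept (top-k selection)\n",
    "def prune_by_score(keys, values, scores, compression_ratio):\n",
    "    n_kept = int(keys.shape[2] * (1 - compression_ratio))\n",
    "    # Indices of the n_kept highest-scoring positions for each head: (batch, heads, n_kept)\n",
    "    indices = scores.topk(n_kept, dim=-1).indices\n",
    "    # Expand to (batch, heads, n_kept, head_dim) so that gather selects whole vectors\n",
    "    indices = indices.unsqueeze(-1).expand(-1, -1, -1, keys.shape[-1])\n",
    "    return keys.gather(2, indices), values.gather(2, indices)\n",
    "\n",
    "\n",
    "toy_keys = torch.randn(1, 2, 20, 128)\n",
    "toy_values = torch.randn(1, 2, 20, 128)\n",
    "toy_scores = -toy_keys.norm(dim=-1)  # same scoring rule as MyKnormPress\n",
    "\n",
    "pruned_keys, pruned_values = prune_by_score(toy_keys, toy_values, toy_scores, 0.25)\n",
    "print(pruned_keys.shape)  # torch.Size([1, 2, 15, 128])"
   ]
  },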
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 Updating the `compress` method "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `compress` method defined in the `BasePress` contains the core logic of the compression and returns compressed keys and values. For instance, in the `ScorerPress` the `compress` calls the `score` method (which is specific to `ScorerPress`) and prune the key-value pairs based on the scores.\n",
    "\n",
    "The following example will show how it works. We will re-implement the `StreamingLLMPress` in a more compact way."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class MyStreamingLLMPress(BasePress):\n",
    "    n_first: int = 1\n",
    "    n_last: int = 8\n",
    "\n",
    "    def compress(\n",
    "        self,\n",
    "        module: nn.Module,\n",
    "        hidden_states: torch.Tensor,\n",
    "        keys: torch.Tensor,\n",
    "        values: torch.Tensor,\n",
    "        attentions: torch.Tensor,\n",
    "        kwargs: dict,\n",
    "    ) -> tuple[torch.Tensor, torch.Tensor]:\n",
    "\n",
    "        mask = torch.ones(keys.shape[-2], dtype=torch.bool, device=keys.device)\n",
    "        mask[self.n_first : -self.n_last] = False\n",
    "        return keys[:, :, mask, :], values[:, :, mask, :]\n",
    "\n",
    "\n",
    "for n_last in [2, 4, 8]:\n",
    "    press = MyStreamingLLMPress(n_last=n_last)\n",
    "    print(f\"\\nn_last: {n_last}\")\n",
    "    print(f\"Last tokens seen by the model: {pipe.tokenizer.decode(tokens.input_ids[0, -n_last:])}\")\n",
    "    print(f\"Answer: {pipe(context, question=question, press=press)['answer']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that in the `compress` method is itself used in the `forward_hook` method which ensures quantization is handled properly and that the compression is only performed during prefilling. While we don't recommend to change the `forward_hook` method directly, you can still modify it if you need to !"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.3 Head-wise compression\n",
    "\n",
    "Since 0.2.0, kvpress support head-wise compression, where the KV cache of each head might be compressed by a different compression ratio. \n",
    "\n",
    "To achieve proper head-wise compression, one should implement a new kernel for attention along with a custom cache class. Instead, the current implementation fakes head-wise compression by updating the pruned keys by a fake key so that the output of the attention layer is not affected. This is implemented through `kvpress.attention_patch.patch_attention_functions`.\n",
    "\n",
    "To implement a method that compresses the KV cache head-wise, one should instantiate the `masked_key_indices` as outlined below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "compression_ratio: 0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Answer: The purpose of this step-by-step guide is to provide a comprehensive and easy-to-follow tutorial on how to create a new press in the KVPress platform. The guide is designed to help users understand the process of setting up a new press, including the\n",
      "\n",
      "compression_ratio: 0.25\n",
      "Answer: The purpose of this guide is to provide a step-by-step process for creating a new press in KVPRESS, which is a popular open-source web server. The guide will cover the necessary steps to set up and configure a new press, including installing\n",
      "\n",
      "compression_ratio: 0.9\n",
      "Answer: This guide is not a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a guide. It is a\n"
     ]
    }
   ],
   "source": [
    "@dataclass\n",
    "class RandomHeadPress(BasePress):\n",
    "\n",
    "    compression_ratio: float = 0.0\n",
    "\n",
    "    def compress(self, module, hidden_states, keys, values, attentions, kwargs):\n",
    "        assert keys.shape[0] == 1, \"Only batch size 1 is supported\"\n",
    "        scores = torch.rand(keys.shape[:-1], device=keys.device)\n",
    "        mask = scores < torch.quantile(scores, self.compression_ratio)\n",
    "        module.masked_key_indices = torch.nonzero(mask, as_tuple=True)\n",
    "        \n",
    "        return keys, values\n",
    "\n",
    "for compression_ratio in [0, 0.25, 0.9]:\n",
    "    press = RandomHeadPress(compression_ratio)\n",
    "    print(f\"\\ncompression_ratio: {compression_ratio}\")\n",
    "    print(f\"Answer: {pipe(context, question=question, press=press)['answer']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Contributing to kvpress"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "All presses should be stored in the `presses` directory. Before opening a pull request with your new press, make sure to \n",
    "- register it in the `__init__.py` file of repository\n",
    "- register the press in [default_presses.py](tests/default_presses.py)\n",
    "- update the README"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
